{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sentence Similarity with Pretrained BERT\n",
    "In this notebook, we use a pretrained [BERT model](https://arxiv.org/abs/1810.04805) as a sentence encoder to measure sentence similarity. We use a [feature extractor](../../utils_nlp/bert/extract_features.py) that wraps [Hugging Face's PyTorch implementation](https://github.com/huggingface/pytorch-pretrained-BERT) of Google's [BERT](https://github.com/google-research/bert). \n",
    "\n",
    "**Note: To learn how to do pre-training on your own, please reference the [AzureML-BERT repo](https://github.com/microsoft/AzureML-BERT) created by Microsoft.**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 00 Global Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import torch\n",
    "import itertools\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scrapbook as sb\n",
    "from collections import OrderedDict\n",
    "\n",
    "sys.path.append(\"../../\")\n",
    "from utils_nlp.models.bert.common import Language, Tokenizer\n",
    "from utils_nlp.models.bert.sequence_encoding import BERTSentenceEncoder, PoolingStrategy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# device config\n",
    "NUM_GPUS = 0\n",
    "\n",
    "# model config\n",
    "LANGUAGE = Language.ENGLISH\n",
    "TO_LOWER = True\n",
    "MAX_SEQ_LENGTH = 128\n",
    "LAYER_INDEX = -2\n",
    "POOLING_STRATEGY = PoolingStrategy.MEAN\n",
    "\n",
    "# path config\n",
    "CACHE_DIR = \"./temp\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists(CACHE_DIR):\n",
    "    os.makedirs(CACHE_DIR, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 01 Define the Sentence Encoder with Pretrained BERT"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `BERTSentenceEncoder` defaults to Pretrained BERT."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 407873900/407873900 [00:15<00:00, 26602678.27B/s]\n",
      "100%|██████████| 231508/231508 [00:00<00:00, 905295.88B/s]\n"
     ]
    }
   ],
   "source": [
    "se = BERTSentenceEncoder(\n",
    "    language=LANGUAGE,\n",
    "    num_gpus=NUM_GPUS,\n",
    "    cache_dir=CACHE_DIR,\n",
    "    to_lower=TO_LOWER,\n",
    "    max_len=MAX_SEQ_LENGTH,\n",
    "    layer_index=LAYER_INDEX,\n",
    "    pooling_strategy=POOLING_STRATEGY,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 02 Compute the Sentence Encodings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `encode` method of the sentence encoder accepts a list of text to encode, as well as the layers we want to extract the embeddings from and the pooling strategy we want to use. The embedding size is 768. We can also return just the values column as a list of numpy arrays by setting the `as_numpy` parameter to True."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 2917.78it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text_index</th>\n",
       "      <th>layer_index</th>\n",
       "      <th>values</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>-2</td>\n",
       "      <td>[0.038080588, 0.0926698, 0.0366186, -0.1218368...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>-2</td>\n",
       "      <td>[0.084241375, 0.099506006, -0.38437817, 0.2164...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   text_index  layer_index                                             values\n",
       "0           0           -2  [0.038080588, 0.0926698, 0.0366186, -0.1218368...\n",
       "1           1           -2  [0.084241375, 0.099506006, -0.38437817, 0.2164..."
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "result = se.encode(\n",
    "    [\"Coffee is good\", \"The moose is across the street\"],\n",
    "    as_numpy=False\n",
    ")\n",
    "result"
   ]
  },
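  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch of how two such encodings could be compared, the cell below computes cosine similarity with NumPy. `cosine_similarity` is a helper defined here for illustration and is not part of `utils_nlp`; the toy vectors stand in for the 768-dimensional encodings. Assuming each entry of `result[\"values\"]` is a fixed-length vector, the same function could be applied to `result[\"values\"].iloc[0]` and `result[\"values\"].iloc[1]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def cosine_similarity(a, b):\n",
    "    \"\"\"Return the cosine of the angle between two 1-D vectors (hypothetical helper).\"\"\"\n",
    "    a = np.asarray(a, dtype=float)\n",
    "    b = np.asarray(b, dtype=float)\n",
    "    # dot product divided by the product of the Euclidean norms\n",
    "    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))\n",
    "\n",
    "\n",
    "# toy vectors standing in for two sentence encodings (assumption, not real BERT output)\n",
    "u = np.array([1.0, 0.0, 1.0])\n",
    "v = np.array([1.0, 1.0, 0.0])\n",
    "print(cosine_similarity(u, v))"
   ]
  },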
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/scrapbook.scrap.json+json": {
       "data": 768,
       "encoder": "json",
       "name": "result",
       "version": 1
      }
     },
     "metadata": {
      "scrapbook": {
       "data": true,
       "display": false,
       "name": "result"
      }
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# for testing\n",
    "size_emb = len(result[\"values\"].iloc[0])\n",
    "sb.glue(\"size_emb\", size_emb)\n"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "Python (nlp_gpu)",
   "language": "python",
   "name": "nlp_gpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
